{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import glob\n",
    "import os\n",
    "import pandas as pd\n",
    "from tensorflow.keras import datasets, layers, models\n",
    "import efficientnet.tfkeras as efn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File layout\n",
    "# data3>train>0\n",
    "# data3>train>1\n",
    "# data3>train>2\n",
    "# data3>train>3\n",
    "# data3>test>test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read train.csv and use its 'image' column as the row index.\n",
    "df = pd.read_csv('train.csv', index_col='image')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the training images. os.path.join keeps the pattern portable across\n",
    "# operating systems, and sorting makes the order reproducible across runs.\n",
    "train_image_path = sorted(glob.glob(os.path.join('train', '*.jpg')))\n",
    "\n",
    "# Look up each file's row in df by its base name and read the first column as the label.\n",
    "# .iloc[0] selects that column by position explicitly (plain [0] on a label-indexed\n",
    "# Series is deprecated). The label is shifted down by one -- presumably converting\n",
    "# 1-based category ids to 0-based class indices; confirm against train.csv.\n",
    "labels = [df.loc[os.path.basename(path)].iloc[0] - 1 for path in train_image_path]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: build a dataset of (image path, label) pairs with tf.data.Dataset.from_tensor_slices\n",
    "image_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (train_image_path, labels)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 20% of the samples (rounded down) for validation; the rest is used for training.\n",
    "train_count = len(labels) - int(len(labels) * 0.2)\n",
    "val_count = len(labels) - train_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_preprocess_image(path,label):\n",
    "    \"\"\"Load one training image from disk and apply random augmentation.\n",
    "\n",
    "    Args:\n",
    "        path: path of a JPEG file, as accepted by tf.io.read_file.\n",
    "        label: label tensor belonging to that image.\n",
    "\n",
    "    Returns:\n",
    "        (image, label): image is a float32 tensor of shape [260, 260, 3] with\n",
    "        values in [0, 1]; label is reshaped to shape [1].\n",
    "    \"\"\"\n",
    "    image = tf.io.read_file(path)\n",
    "    image = tf.image.decode_jpeg(image,channels=3)\n",
    "    # Resize slightly larger than the target so the random crop below can shift the view.\n",
    "    image = tf.image.resize(image,[300,300])\n",
    "    # Scale to [0, 1] BEFORE the brightness jitter: random_brightness adds a delta drawn\n",
    "    # from [-0.2, 0.2), which would be negligible on the original 0-255 scale.\n",
    "    image = tf.cast(image,tf.float32)/255\n",
    "    image = tf.image.random_crop(image,[260,260,3])\n",
    "    image = tf.image.random_flip_left_right(image)\n",
    "    image = tf.image.random_flip_up_down(image)\n",
    "    image = tf.image.random_brightness(image,0.2)\n",
    "    # The jitter can push values outside [0, 1]; clip so the range matches the\n",
    "    # un-augmented preprocessing used for validation and test images.\n",
    "    image = tf.clip_by_value(image,0.0,1.0)\n",
    "    label = tf.reshape(label,[1])\n",
    "    return image,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_preprocess_image_test(path,label):\n",
    "    \"\"\"Deterministic preprocessing without augmentation.\n",
    "\n",
    "    Decodes the JPEG at `path` into 3 channels, resizes it to 260x260 and divides\n",
    "    the float32 pixel values by 255. Returns (image, label) with label reshaped to [1].\n",
    "    \"\"\"\n",
    "    raw_bytes = tf.io.read_file(path)\n",
    "    decoded = tf.image.decode_jpeg(raw_bytes,channels=3)\n",
    "    resized = tf.image.resize(decoded,[260,260])\n",
    "    scaled = tf.cast(resized,tf.float32)/255\n",
    "    return scaled,tf.reshape(label,[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle once with a fixed seed before splitting, so the validation set is a random\n",
    "# sample rather than simply the first files in listing order.\n",
    "# reshuffle_each_iteration=False keeps the order identical on every pass, which is what\n",
    "# guarantees that take() and skip() select disjoint samples.\n",
    "shuffled_ds = image_ds.shuffle(train_count+val_count,seed=42,reshuffle_each_iteration=False)\n",
    "\n",
    "# Validation set: the first val_count shuffled samples; training set: everything after them.\n",
    "image_val_ds = shuffled_ds.take(val_count)\n",
    "image_train_ds = shuffled_ds.skip(val_count)\n",
    "\n",
    "# Step 3: preprocessing (the preprocessing functions are defined in the cells above)\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "image_train_ds = image_train_ds.map(load_preprocess_image,num_parallel_calls=AUTOTUNE)\n",
    "image_val_ds = image_val_ds.map(load_preprocess_image_test,num_parallel_calls=AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "# Shuffle within one pass over the data, then repeat indefinitely. Shuffling before\n",
    "# repeat() keeps samples of different passes from being mixed, and an unbounded repeat()\n",
    "# cannot run out of batches; model.fit limits each epoch through steps_per_epoch.\n",
    "image_train_ds = image_train_ds.shuffle(train_count).repeat().batch(BATCH_SIZE)\n",
    "image_train_ds = image_train_ds.prefetch(AUTOTUNE)  # prepare upcoming batches while the current one is consumed\n",
    "\n",
    "# The validation data is neither shuffled nor repeated.\n",
    "image_val_ds = image_val_ds.batch(BATCH_SIZE)\n",
    "image_val_ds = image_val_ds.prefetch(AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): covn_base was not defined anywhere in this notebook, so this cell raised\n",
    "# NameError on a fresh kernel. 260x260 (the size produced by the preprocessing functions)\n",
    "# is the default resolution of EfficientNet-B2, so B2 is assumed here -- confirm the\n",
    "# intended variant (the saved model file name mentions 'B52').\n",
    "# pooling='avg' makes the base emit one feature vector per image for the Dense head.\n",
    "covn_base = efn.EfficientNetB2(weights='imagenet',include_top=False,\n",
    "                               input_shape=(260,260,3),pooling='avg')\n",
    "covn_base.trainable = False  # freeze the pretrained weights for the first training stage\n",
    "\n",
    "# Classification head stacked on top of the frozen convolutional base.\n",
    "model = keras.Sequential()\n",
    "model.add(covn_base)\n",
    "model.add(keras.layers.Dropout(0.2))\n",
    "model.add(keras.layers.Dense(1024,activation='relu'))\n",
    "model.add(keras.layers.Dropout(0.2))\n",
    "model.add(keras.layers.Dense(512,activation='relu'))\n",
    "model.add(keras.layers.Dropout(0.2))\n",
    "model.add(keras.layers.Dense(512,activation='relu'))\n",
    "# softmax (not sigmoid): sparse_categorical_crossentropy expects the 8 outputs to form a\n",
    "# probability distribution over mutually exclusive classes.\n",
    "model.add(keras.layers.Dense(8,activation='softmax'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "8/8 [==============================] - 71s 9s/step - loss: 1.9182 - acc: 0.2852 - val_loss: 1.9199 - val_acc: 0.3125\n",
      "Epoch 2/8\n",
      "8/8 [==============================] - 84s 11s/step - loss: 1.8865 - acc: 0.3633 - val_loss: 1.9110 - val_acc: 0.3125\n",
      "Epoch 3/8\n",
      "8/8 [==============================] - 72s 9s/step - loss: 1.8530 - acc: 0.3438 - val_loss: 1.8275 - val_acc: 0.3125\n",
      "Epoch 4/8\n",
      "8/8 [==============================] - 71s 9s/step - loss: 1.7756 - acc: 0.3164 - val_loss: 1.8526 - val_acc: 0.3125\n",
      "Epoch 5/8\n",
      "8/8 [==============================] - 70s 9s/step - loss: 1.6548 - acc: 0.3906 - val_loss: 1.8291 - val_acc: 0.3125\n",
      "Epoch 6/8\n",
      "8/8 [==============================] - 67s 8s/step - loss: 1.6167 - acc: 0.3281 - val_loss: 1.8404 - val_acc: 0.3281\n",
      "Epoch 7/8\n",
      "8/8 [==============================] - 67s 8s/step - loss: 1.5572 - acc: 0.3633 - val_loss: 1.7978 - val_acc: 0.3438\n",
      "Epoch 8/8\n",
      "3/8 [==========>...................] - ETA: 27s - loss: 1.5977 - acc: 0.3646"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-17-5351853e1569>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[1;31m#训练\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m history = model.fit(\n\u001b[0m\u001b[0;32m      8\u001b[0m     \u001b[0mimage_train_ds\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m     \u001b[0msteps_per_epoch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrain_count\u001b[0m\u001b[1;33m//\u001b[0m\u001b[0mBATCH_SIZE\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\keras\\engine\\training.py\u001b[0m in \u001b[0;36m_method_wrapper\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m    106\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m_method_wrapper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    107\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_in_multi_worker_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m  \u001b[1;31m# pylint: disable=protected-access\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 108\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    109\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    110\u001b[0m     \u001b[1;31m# Running inside `run_distribute_coordinator` already.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m   1096\u001b[0m                 batch_size=batch_size):\n\u001b[0;32m   1097\u001b[0m               \u001b[0mcallbacks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mon_train_batch_begin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1098\u001b[1;33m               \u001b[0mtmp_logs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1099\u001b[0m               \u001b[1;32mif\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshould_sync\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1100\u001b[0m                 \u001b[0mcontext\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0masync_wait\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    778\u001b[0m       \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    779\u001b[0m         \u001b[0mcompiler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m\"nonXla\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 780\u001b[1;33m         \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    781\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    782\u001b[0m       \u001b[0mnew_tracing_count\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_tracing_count\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m_call\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    805\u001b[0m       \u001b[1;31m# In this case we have created variables on the first call, so we run the\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    806\u001b[0m       \u001b[1;31m# defunned version which is guaranteed to never create variables.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 807\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_stateless_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# pylint: disable=not-callable\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    808\u001b[0m     \u001b[1;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_stateful_fn\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    809\u001b[0m       \u001b[1;31m# Release the lock early so that multiple threads can perform the call\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   2827\u001b[0m     \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2828\u001b[0m       \u001b[0mgraph_function\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_maybe_define_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2829\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_filtered_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# pylint: disable=protected-access\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2830\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2831\u001b[0m   \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_filtered_call\u001b[1;34m(self, args, kwargs, cancellation_manager)\u001b[0m\n\u001b[0;32m   1841\u001b[0m       \u001b[0;31m`\u001b[0m\u001b[0margs\u001b[0m\u001b[0;31m`\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1842\u001b[0m     \"\"\"\n\u001b[1;32m-> 1843\u001b[1;33m     return self._call_flat(\n\u001b[0m\u001b[0;32m   1844\u001b[0m         [t for t in nest.flatten((args, kwargs), expand_composites=True)\n\u001b[0;32m   1845\u001b[0m          if isinstance(t, (ops.Tensor,\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[1;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[0;32m   1921\u001b[0m         and executing_eagerly):\n\u001b[0;32m   1922\u001b[0m       \u001b[1;31m# No tape is watching; skip to running the function.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1923\u001b[1;33m       return self._build_call_outputs(self._inference_function.call(\n\u001b[0m\u001b[0;32m   1924\u001b[0m           ctx, args, cancellation_manager=cancellation_manager))\n\u001b[0;32m   1925\u001b[0m     forward_backward = self._select_forward_and_backward_functions(\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36mcall\u001b[1;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[0;32m    543\u001b[0m       \u001b[1;32mwith\u001b[0m \u001b[0m_InterpolateFunctionError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    544\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mcancellation_manager\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 545\u001b[1;33m           outputs = execute.execute(\n\u001b[0m\u001b[0;32m    546\u001b[0m               \u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msignature\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    547\u001b[0m               \u001b[0mnum_outputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_outputs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[1;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[0;32m     57\u001b[0m   \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     58\u001b[0m     \u001b[0mctx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 59\u001b[1;33m     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0m\u001b[0;32m     60\u001b[0m                                         inputs, attrs, num_outputs)\n\u001b[0;32m     61\u001b[0m   \u001b[1;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Compile. The labels are integer class indices, hence the sparse categorical cross-entropy.\n",
    "# `learning_rate` is the supported argument name; `lr` is a deprecated alias.\n",
    "model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),\n",
    "             loss = 'sparse_categorical_crossentropy',\n",
    "             metrics=['acc'])\n",
    "\n",
    "# Train. The step counts use floor division, so a trailing partial batch is not counted.\n",
    "history = model.fit(\n",
    "    image_train_ds,\n",
    "    steps_per_epoch=train_count//BATCH_SIZE,\n",
    "    epochs=8,\n",
    "    validation_data=image_val_ds,\n",
    "    validation_steps=val_count//BATCH_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/16\n",
      "8/8 [==============================] - 39s 5s/step - loss: 1.6745 - acc: 0.3242 - val_loss: 1.8625 - val_acc: 0.3125\n",
      "Epoch 10/16\n",
      "8/8 [==============================] - 36s 4s/step - loss: 1.6012 - acc: 0.3750 - val_loss: 1.8226 - val_acc: 0.3125\n",
      "Epoch 11/16\n",
      "1/8 [==>...........................] - ETA: 0s - loss: 1.4450 - acc: 0.3000WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 64 batches). You may need to use the repeat() function when building your dataset.\n",
      "1/8 [==>...........................] - 8s 8s/step - loss: 1.4450 - acc: 0.3000 - val_loss: 1.8243 - val_acc: 0.3125\n"
     ]
    }
   ],
   "source": [
    "# Epoch bookkeeping: fine-tuning continues the epoch numbering after the initial stage.\n",
    "initial_epochs=8\n",
    "fine_tune_epochs=8\n",
    "total_epoch = initial_epochs+fine_tune_epochs\n",
    "\n",
    "# Unfreeze the base, then re-freeze every layer except the last three.\n",
    "fine_tune_at = -3\n",
    "covn_base.trainable=True\n",
    "for frozen_layer in covn_base.layers[:fine_tune_at]:\n",
    "    frozen_layer.trainable = False\n",
    "\n",
    "# Recompile after changing the trainable flags, using a learning rate of 0.0005.\n",
    "model.compile(\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    optimizer=keras.optimizers.Adam(0.0005),\n",
    "    metrics=['acc'],\n",
    ")\n",
    "\n",
    "# Second training stage\n",
    "fit_options = dict(\n",
    "    steps_per_epoch=train_count//BATCH_SIZE,\n",
    "    epochs=total_epoch,\n",
    "    initial_epoch=initial_epochs,\n",
    "    validation_data=image_val_ds,\n",
    "    validation_steps=val_count//BATCH_SIZE,\n",
    ")\n",
    "history = model.fit(image_train_ds, **fit_options)\n",
    "\n",
    "# Persist the fine-tuned model.\n",
    "model.save('mstz_model_EfficientNetB52.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a single image for inference\n",
    "def load_preprocess_images(path):\n",
    "    \"\"\"Read one JPEG and return it as a batch of one.\n",
    "\n",
    "    The image is decoded into 3 channels, resized to 260x260, converted to float32 and\n",
    "    divided by 255; a leading axis is then added, giving shape [1, 260, 260, 3].\n",
    "    \"\"\"\n",
    "    pixels = tf.image.decode_jpeg(tf.io.read_file(path),channels=3)\n",
    "    pixels = tf.cast(tf.image.resize(pixels,[260,260]),tf.float32)\n",
    "    return tf.expand_dims(pixels/255,0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "......................................................................................................................................................"
     ]
    }
   ],
   "source": [
    "# Sorted so that the prediction order is reproducible across runs.\n",
    "test_image_path = sorted(glob.glob('test/testA/*.jpg'))\n",
    "image_count = len(test_image_path)\n",
    "\n",
    "# Predicted class index per test image, in the same order as test_image_path.\n",
    "values = []\n",
    "if image_count:\n",
    "    # Stream the images through tf.data and predict in batches, instead of holding every\n",
    "    # decoded image in a Python list and calling model.predict once per image.\n",
    "    # load_preprocess_images returns shape [1,260,260,3]; [0] drops that leading axis so\n",
    "    # batch() can stack the images into [batch,260,260,3] tensors.\n",
    "    test_ds = tf.data.Dataset.from_tensor_slices(test_image_path)\n",
    "    test_ds = test_ds.map(lambda path: load_preprocess_images(path)[0],num_parallel_calls=AUTOTUNE)\n",
    "    test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)\n",
    "\n",
    "    pred = model.predict(test_ds,verbose=1)\n",
    "    # The highest-scoring class of each row is the prediction for that image.\n",
    "    values = np.argmax(pred,axis=1).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the submission file: a header line, then one 'image_id,category_id' row per test image.\n",
    "with open('sub.csv','w',encoding='utf-8') as f:\n",
    "    f.write('image_id,category_id\\n')\n",
    "    for key,value in zip(test_image_path,values):\n",
    "        # basename/splitext handle the platform's path separator; splitting on a\n",
    "        # backslash only worked for Windows-style paths.\n",
    "        image_id = os.path.splitext(os.path.basename(key))[0]+'.jpg'\n",
    "        # +1 undoes the -1 shift that was applied to the labels before training.\n",
    "        f.write('{0},{1}\\n'.format(image_id,value+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
